Deep Learning CA1: CIFAR-10

In this notebook, I will develop a deep learning model to classify images from the CIFAR-10 dataset.

I start by installing a few extra libraries which I will use later on in the notebook. The key library here is TensorFlow Addons, which extends TensorFlow and Keras with implementations of more recent deep learning techniques.

Library Imports

I begin by importing the essential libraries. These are:

Setting Random Seed

To improve reproducibility, I make sure to set a random seed. 42 is, of course, the answer to everything in the universe.
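A minimal sketch of the seeding step (assuming NumPy is used alongside Keras; in the notebook itself, `tf.random.set_seed(42)` should also be called so TensorFlow ops are reproducible):

```python
import random

import numpy as np

SEED = 42

def set_seed(seed: int = SEED) -> None:
    # Seed Python's and NumPy's RNGs; in the notebook, also call
    # tf.random.set_seed(seed) for TensorFlow-side reproducibility.
    random.seed(seed)
    np.random.seed(seed)

set_seed()
first_draw = np.random.rand(3)
set_seed()
second_draw = np.random.rand(3)
assert np.allclose(first_draw, second_draw)  # same seed, same sequence
```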

Problem Statement

Our problem statement is to develop a deep learning model that is capable of classifying images from the CIFAR-10 dataset.

What is CIFAR-10?

What are our goals for the model?

Our goal is to develop a model that is able to generalize well to new data points (that is, one that does not overfit).

The main reason for setting this goal is that while high accuracy on the training set is nice, a model we actually want to use in the real world must generalize well to new data.

Optimizing Metrics (what I want to optimize)

Satisfying Metrics (what still needs to be satisfied)

Data Ingestion

We will begin by loading our data. Since a function to download the dataset is already included in Keras, we will make use of it to quickly load our data.

Data Splits

When training and evaluating our model, we will split our data into a training, validation, and testing set.

The training set will be used to train the model, the validation set for model tuning, and the testing set will be used to evaluate the final model, ensuring that it is able to generalize. (and does not overfit to the validation set as a result of our model tuning)

| Split      | Size |
|------------|------|
| Training   | 40K  |
| Validation | 10K  |
| Testing    | 10K  |

In practice, this just means that when training our model, we will use 20% of our training data for validation. We choose a large validation set because our tuning decisions are based on it, and a larger set reduces the risk of overfitting to it.
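The 40K/10K split could be sketched as follows (hypothetical helper on dummy arrays; the notebook could equally pass `validation_split=0.2` to `model.fit`):

```python
import numpy as np

def train_val_split(x, y, val_fraction=0.2, seed=42):
    # Shuffle once with a fixed seed, then hold out val_fraction for validation.
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    n_val = int(len(x) * val_fraction)
    return x[idx[n_val:]], y[idx[n_val:]], x[idx[:n_val]], y[idx[:n_val]]

# Dummy stand-ins for the 50K CIFAR-10 training images/labels
x = np.zeros((50, 32, 32, 3))
y = np.arange(50)
x_tr, y_tr, x_val, y_val = train_val_split(x, y)
assert len(x_tr) == 40 and len(x_val) == 10  # 80/20 split
```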

Each numbered label in the dataset represents a specific class of object. To make the labels more readable, we will use a dictionary to map each number to the corresponding description.
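The CIFAR-10 class names are fixed by the dataset, so the mapping looks like this (the dictionary name is mine):

```python
# Integer label -> human-readable CIFAR-10 class name
CIFAR10_LABELS = {
    0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
    5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck",
}

print(CIFAR10_LABELS[0])  # airplane
```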

Exploratory Data Analysis

Before we even start modelling, it's important to get a grip on the data. There are a few key questions to ask here:

Each image is an RGB 32x32 image. It is important to note that these images are not very large, meaning that the neural network does not need a particularly large receptive field.

Visualizing the Dataset

Let's take a look at a subset of random images first.

Let's take a look at what each class looks like.

We can see that the images are fairly diverse, with different viewing angles. We also note that some of the classes are quite broad. For example, we can see that along with real planes, toy planes are also included in the airplane class.

Visualization of Class Distribution

We can see that there is an even class balance. This means that we can make use of accuracy as our primary metric, as there is no real "minority" class, so accuracy is a good measure of classification performance.

What does the distribution of the images look like?

These are the average and standard deviation of pixel intensities on each color channel (red, green, blue).

What is the "average" image?

Although the average image is fairly blurry, we can still roughly make out the automobile, horse, and truck. The average images for the other classes are harder to make out, which could suggest that those classes are more difficult to predict.

Basic Data Preprocessing

Before we model the data, it is important to do some basic pre-processing on it.

Encoding the Target Labels

As they are, the current labels are in a label encoded format. We will one hot encode the labels by using the to_categorical function from the Keras utilities.

From the one hot encoded label, we can use argmax to get back the original label in label encoded form.
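A NumPy sketch of this round trip (the notebook uses `keras.utils.to_categorical`; this hypothetical helper mirrors its behaviour for integer labels):

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    # NumPy equivalent of keras.utils.to_categorical for integer labels
    labels = np.asarray(labels).ravel()
    one_hot = np.zeros((labels.size, num_classes), dtype=np.float32)
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot

y = np.array([3, 1, 9])
y_onehot = to_one_hot(y)
recovered = y_onehot.argmax(axis=1)  # argmax inverts the encoding
assert (recovered == y).all()
```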

Scaling and Normalizing the Inputs

From prior experimentation, I have found that normalizing the inputs helps improve the accuracy of the resulting model, as it converges faster. This is because the optimization algorithm we will be using, SGD, converges better when the features are on approximately the same scale.

Since we are normalizing the data, the resulting data will be centered around 0 with a standard deviation of 1, and thus we don't need to rescale the image beforehand.

Normalizing the inputs means that we will calculate the mean and standard deviation of the training set, and then apply the formula below

$$ X_{channel} = \frac{X_{channel} - \mu_{channel}}{\sigma_{channel}} $$

Note that we prevent data leakage by ensuring we don't use any of the validation/testing data to calculate the per-channel mean and std. This is why we did our split into train-val-test beforehand.
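The procedure above might be sketched as follows (helper names are mine; note that the statistics come from the training set only):

```python
import numpy as np

def fit_channel_stats(x_train):
    # Per-channel mean/std over all training pixels. Only the training
    # set touches these statistics, which avoids data leakage.
    mean = x_train.mean(axis=(0, 1, 2))
    std = x_train.std(axis=(0, 1, 2))
    return mean, std

def normalize(x, mean, std):
    return (x - mean) / std

rng = np.random.default_rng(0)
x_train = rng.random((16, 32, 32, 3)).astype(np.float32)
mean, std = fit_channel_stats(x_train)
x_norm = normalize(x_train, mean, std)  # reuse mean/std for val/test
```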

This is what an image looks like after it has been normalized.

Basic Data Augmentation

Since our goal is to prevent overfitting, we also apply data augmentation. Data augmentation is a method to reduce the variance of a model by providing it with more training data. Here, this is done by applying random flips and crops to an image, creating variations of it.
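A NumPy sketch of these two augmentations (random horizontal flip plus a random crop after padding, a common CIFAR recipe; the helper name and padding size are my assumptions):

```python
import numpy as np

def augment(image, pad=4, rng=None):
    # Random horizontal flip + random crop after reflection padding.
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # flip left-right
    h, w = image.shape[:2]
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w]

img = np.zeros((32, 32, 3))
assert augment(img).shape == (32, 32, 3)  # output keeps the input size
```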

I picked these simple augmentations for the following reasons:

Nevertheless, I will try out a stronger data augmentation method later on as part of my experimentation.

Modelling

Once we've set up a pre-processing pipeline, we will begin training various models.

Optimization Algorithm

For all models, we will train using stochastic gradient descent. This choice was made because SGD tends to generalize better than adaptive optimizers such as Adam on tasks like image classification. Since our primary focus is on building a robust model, it makes sense to use the SGD optimizer. More specifically, we will use SGD with momentum.

In addition, I will make use of the Cosine Learning Rate Scheduler with Restarts which comes from the SGD with Warm Restarts paper. A learning rate scheduler decides how the learning rate should be adjusted during training.

This scheduler uses a cosine wave to decay the learning rate within a cycle, then increases it again at the start of the next cycle to simulate a restart of training. This has a few advantages: the warm restart allows the model to escape bad local minima (the increased learning rate lets the model "jump" out of a bad minimum and find a better one), and as shown in the Snapshot Ensembling paper, a copy of the model can be saved before each warm restart to cheaply ensemble models together ("train 1, get M for free").
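The schedule from the SGDR paper can be sketched as follows (Keras also ships `tf.keras.optimizers.schedules.CosineDecayRestarts`; the parameter names and defaults below are my assumptions):

```python
import math

def sgdr_lr(epoch, eta_max=0.1, eta_min=0.0, t_0=10, t_mult=2):
    # Cosine decay within a cycle; at each warm restart the rate jumps
    # back to eta_max and the cycle length grows by a factor of t_mult.
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

assert abs(sgdr_lr(0) - 0.1) < 1e-12       # start of first cycle: maximum rate
assert abs(sgdr_lr(5) - 0.05) < 1e-9       # halfway through: half the rate
assert abs(sgdr_lr(10) - 0.1) < 1e-12      # warm restart: back to the maximum
```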

Choice of Hyperparameters

Training

Since each warm restart might cause a temporary degradation in performance, I increase the patience of my Early Stopping to give the model a better chance to recover.

Experiment Logger

To make the process of model training more standardized, I have created a class which will keep track of the various models tested during the notebook.

I create a simple utility function to plot the loss and accuracy of the model as it trains.

To evaluate my models, I created a ModelEvaluator class to keep track of the different experiments conducted and save them offline. By doing this, I can easily record my progress.

Baseline: Fully Connected Neural Network

As a simple baseline, I build a fully connected neural network.


We can see that the overall performance of the model is poor: both the training and validation accuracy are low.

To do better, I decide to move towards a CNN architecture, since CNNs are well suited to the problem of image classification.

Custom VGGNet Inspired Network (10 Conv)

To begin with, we will construct a simple CNN, loosely based on VGGNet.

Instead of taking the VGG architecture wholesale, I downsize it and make certain improvements, for the following reasons:

So, here's what I've done:

The end result is a much smaller network than the original VGG networks

My implementation is adapted from the one given in Dive into Deep Learning.

Without Data Augmentation


By observing the learning curve, we can see that this baseline CNN begins overfitting to our data. As such, our attention now turns towards reducing the variance of our model.

With Basic Data Augmentation


WideResNet

To improve upon our baseline, I make use of the Wide ResNet architecture. In summary, Wide ResNet is a residual network that increases width instead of depth. Although this would be expected to result in more overfitting, it can be counteracted by regularization, dropout, and data augmentation. The benefits of going wide are that:

Without Data Augmentation


We can observe that Wide ResNet is able to train very well, reaching 93% validation accuracy in less than 60 epochs. However, we see that the model heavily overfits.

Model Improvement

CutMix

I observe that even with basic data augmentation, dropout and L2 regularization, the model still heavily overfits.

To counteract this overfitting, I employ a SOTA data augmentation method known as CutMix.

It works by cutting a patch out of one image and pasting it onto another. In addition, the label is modified to reflect that the image now contains two classes. This helps the model learn to identify an object from a partial view of it, improving generalization.

Implementation adapted from: https://keras.io/examples/vision/cutmix/
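The core idea can be sketched in NumPy as follows (the notebook's actual implementation follows the Keras example linked above; the helper and parameter names here are my assumptions):

```python
import numpy as np

def cutmix_pair(x_a, y_a, x_b, y_b, alpha=1.0, rng=None):
    # Paste a random box from image B into image A, then mix the one-hot
    # labels in proportion to the area each image contributes.
    if rng is None:
        rng = np.random.default_rng()
    h, w = x_a.shape[:2]
    lam = rng.beta(alpha, alpha)
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = rng.integers(h), rng.integers(w)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    mixed = x_a.copy()
    mixed[y1:y2, x1:x2] = x_b[y1:y2, x1:x2]
    lam_adj = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)  # area actually kept from A
    return mixed, lam_adj * y_a + (1.0 - lam_adj) * y_b

x_a, x_b = np.zeros((32, 32, 3)), np.ones((32, 32, 3))
y_a, y_b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
mixed, y_mix = cutmix_pair(x_a, y_a, x_b, y_b, rng=np.random.default_rng(0))
assert abs(y_mix.sum() - 1.0) < 1e-9  # mixed label still sums to one
```

Note that the label weight is recomputed from the actual pasted area, since clipping the box at the image border can shrink it.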

We stop the training at 100 epochs since the model appears to stop improving.

Impressively, CutMix seems to allow us to reach the same level of performance while greatly reducing overfitting.

Squeeze and Excite Blocks

Paper


It can be seen that WRN with SE blocks and CutMix does very well, managing to avoid overfitting while achieving a high validation accuracy of 94.6%.

Mish Activation

As an additional improvement, I try replacing the activation function with the Mish activation. Mish is a proposed activation function defined as $$ f(z) = z \cdot \tanh (\mathrm{softplus}(z)) $$

It has been shown to provide improvements over ReLU on many benchmark datasets, at some cost in training time.
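For reference, Mish is straightforward to write down (a NumPy sketch; in the notebook, the TensorFlow Addons implementation `tfa.activations.mish` would be used instead):

```python
import numpy as np

def mish(z):
    # f(z) = z * tanh(softplus(z)); log1p(exp(z)) is softplus.
    # A numerically stable softplus would be preferable for large z.
    return z * np.tanh(np.log1p(np.exp(z)))

assert mish(0.0) == 0.0                 # passes through the origin
assert abs(mish(10.0) - 10.0) < 1e-3    # behaves like identity for large z
```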

It seems that Mish activation does not help our model at all, so we stick to ReLU activation.

Snapshot Ensembling of the Hypertuned Model

By following the procedure in Snapshot Ensembles: Train 1, Get M for Free, I can make use of the model weights saved before each warm restart to build an ensemble. The logic is that the warm restart schedule means each snapshot taken before a restart sits at a different local minimum, so by taking snapshots we can effectively ensemble many diverse models at no additional training cost.

Now normally in deployment, an ensemble deep learning model would not be used since it is computationally expensive during inference, but as an interesting exercise to see how much I can improve the model, I think it's worth a try.

To do this, I will:

  1. Load the model from the epochs before each warm restart. This is possible as the model has been saved at every epoch
  2. Make a meta model, that just takes the average prediction of all the models
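Step 2 amounts to averaging class probabilities, which can be sketched as follows (hypothetical helper; each entry in the input list is one snapshot's softmax output):

```python
import numpy as np

def ensemble_predict(snapshot_probs):
    # Average the class probabilities across snapshots, then take the
    # argmax of the averaged distribution as the final prediction.
    return np.mean(snapshot_probs, axis=0).argmax(axis=1)

# Two toy "snapshots" that disagree on the first example
p1 = np.array([[0.6, 0.4], [0.1, 0.9]])
p2 = np.array([[0.2, 0.8], [0.3, 0.7]])
preds = ensemble_predict([p1, p2])
```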


To ensemble our models, I pick three checkpoints which are from before each warm restart, and load them. Their predictions are averaged and used to make a final prediction.

We can see that ensembling the three models together results in a 95% validation accuracy. This does come at a cost to computation time during inference. However, given that we are not aiming for speed here, the improved validation accuracy helps us achieve our goal of having a more robust model.

Evaluating the Final Model

The final model is as follows:

Testing Set

To ensure the model generalizes well, I evaluate it on the testing set.

As can be observed, the test set accuracy is very close to the validation set accuracy, so we can safely say that the model does not appear to be overfitting.

From the classification report, we see that

Error Analysis

I try to gain a better understanding of the examples where the model does badly.

In some of these images (cat vs dog), the low resolution makes it unclear whether the picture contains a cat or a dog, as both have similar shapes. This explains why the model does less well for cats and dogs.

Summary